Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix biasScaleShape of GroupNormalizationV21 to support ranks > 4 #3030

Merged

Conversation

jorickert
Copy link
Collaborator

Before this PR the oneDimShape assumed a spacial rank of two, which is only correct for rank==4.

@hamptonm1
Copy link
Collaborator

@jorickert Hello!! Can you direct me to the link or document which validates your findings please? I just want to have a better understanding. Thank you!

@jorickert
Copy link
Collaborator Author

jorickert commented Dec 17, 2024

@hamptonm1
According to the onnx spec the input for GroupNormV21 has the dimensions (N, C, D1,D2, Dn), scale and bias have the dimension C

The formula for GroupNorm is:
y = scale * (x - mean) / sqrt(variance + epsilon) + bias where the mean and variance are computed per instance per group of channels. The formula for LayerNorm is generally the same, the main difference (in onnx) being that it allows the selection of axes for the mean and variance.

Internally, the GroupNorm reshapes the input to (N, G, C // G, D1, D2, Dn) and then performs a LayerNorm like operation on it, with the reduction axes being C // G to Dn. This can be seen in the GroupNorm paper https://arxiv.org/pdf/1803.08494 figure 3.

The decomposition/conversion mimics this:

  1. Manually reshape the input to (N, G, C // G, D1, D2, Dn)
  2. Perform a LayerNorm with axes= C // G to Dn
  3. Manually reshape to (N, C, D1,D2, Dn)
    Additionally, it is required to reshape the scale and bias. The input scale and bias have shape C.
    To make them compatible with the LayerNorm and broadcasting, the following reshape needs to be done:
    C -> (G, C // G, 1, 1, 1) where the number of once equals the number of spacial dimensions.
    Without my PR, the reshape is to (G, C // G, 1, 1), no matter how many spacial dimensions exist.

While writing this down, I realized, that the lit tests are still wrong for GroupNormV21. The input scale and bias have the size C//G = 2 (which is correct for GroupNormV18) , but should be C = 4 . I will fix this

Copy link
Collaborator

@AlexandreEichenberger AlexandreEichenberger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, let me know if you need me to merge it (and if so, when you are ready to do so),

Copy link
Collaborator

@hamptonm1 hamptonm1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay it works for me then!

@jorickert jorickert force-pushed the upstream.jrickert.groupnormv21 branch from aaa9a68 to 0c65cec Compare December 30, 2024 18:52
@jorickert
Copy link
Collaborator Author

@AlexandreEichenberger It would be nice if you could merge this. Thank you

jorickert and others added 2 commits December 30, 2024 13:54
For V21 the scale and bias have size C, not C//G

Signed-off-by: Rickert, Jonas <[email protected]>
@jorickert jorickert force-pushed the upstream.jrickert.groupnormv21 branch from 0c65cec to 044c80c Compare December 30, 2024 18:54
@AlexandreEichenberger AlexandreEichenberger merged commit c0d447f into onnx:main Jan 2, 2025
7 checks passed
@jenkins-droid
Copy link
Collaborator

Jenkins Linux s390x Build #16093 [push] Fix biasScaleShape of Gr... started at 13:08

@jenkins-droid
Copy link
Collaborator

Jenkins Linux amd64 Build #16091 [push] Fix biasScaleShape of Gr... started at 12:08

@jenkins-droid
Copy link
Collaborator

Jenkins Linux ppc64le Build #15120 [push] Fix biasScaleShape of Gr... started at 13:25

@jenkins-droid
Copy link
Collaborator

Jenkins Linux s390x Build #16093 [push] Fix biasScaleShape of Gr... passed after 1 hr 29 min

@jenkins-droid
Copy link
Collaborator

Jenkins Linux amd64 Build #16091 [push] Fix biasScaleShape of Gr... passed after 1 hr 30 min

@jenkins-droid
Copy link
Collaborator

Jenkins Linux ppc64le Build #15120 [push] Fix biasScaleShape of Gr... passed after 2 hr 26 min

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants